{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tutorial 03: weak imposition of Dirichlet BCs by a Lagrange multiplier (nonlinear problem)\n",
    "\n",
    "In this tutorial we solve the problem\n",
    "\n",
    "$$\\begin{align*}\n",
    "&\\min_{u} \\int_\\Omega \\left\\{ (1 + u^2)\\ |\\nabla u|^2 - u \\right\\} dx,\\\\\n",
    "&\\text{s.t. } u = g\\text{ on }\\Gamma = \\partial \\Omega\n",
    "\\end{align*}$$\n",
    "where $\\Omega$ is the unit ball in 2D.\n",
    "\n",
    "The optimality conditions result in the following nonlinear problem\n",
    "\n",
    "$$\\begin{align*}\n",
    "&\\int_\\Omega (1+u^2)\\ \\nabla u \\cdot \\nabla v dx + \\int_\\Omega u \\ |\\nabla u|^2 v dx = \\int_\\Omega v dx\\\\\n",
    "&\\text{s.t. } u = g\\text{ on }\\Gamma = \\partial \\Omega\n",
    "\\end{align*}$$\n",
    "\n",
    "\n",
    "We compare the following two cases:\n",
    "* **strong imposition of Dirichlet BCs**:\n",
    "the corresponding weak formulation is\n",
    "$$\n",
    "\\text{find } u \\in V_g \\text{ s.t. } \\int_\\Omega (1+u^2)\\ \\nabla u \\cdot \\nabla v dx + \\int_\\Omega u \\ |\\nabla u|^2 v dx = \\int_\\Omega v dx, \\quad \\forall v \\in V_0\\\\\n",
    "$$\n",
    "where\n",
    "$$\n",
    "V_g = \\{v \\in H^1(\\Omega): v|_\\Gamma = g\\},\\\\\n",
    "V_0 = \\{v \\in H^1(\\Omega): v|_\\Gamma = 0\\}.\\\\\n",
    "$$\n",
    "* **weak imposition of Dirichlet BCs**: this requires an introduction of a multiplier $\\lambda$ which is restricted to $\\Gamma$, and solves\n",
    "$$\n",
    "\\text{find } w, \\lambda \\in V \\times M \\text{ s.t. }\\\\\n",
    "\\begin{cases}\n",
    "\\int_\\Omega (1+u^2)\\ \\nabla u \\cdot \\nabla v dx + \\int_\\Omega u \\ |\\nabla u|^2 v dx + \\int_\\Gamma \\lambda v = \\int_\\Omega v, & \\forall v \\in V,\\\\\n",
    "\\int_\\Gamma w \\mu = \\int_\\Gamma g \\mu, & \\forall \\mu \\in M\n",
    "\\end{cases}\n",
    "$$\n",
    "where\n",
    "$$\n",
    "V = H^1(\\Omega),\\\\\n",
    "M = L^{2}(\\Gamma).\\\\\n",
    "$$\n",
    "\n",
    "This example is a prototypical case of problems containing subdomain/boundary restricted variables (the Lagrange multiplier, in this case)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from mpi4py import MPI\n",
    "from petsc4py import PETSc\n",
    "from ufl import derivative, grad, inner, Measure, replace, TestFunction, TrialFunction\n",
    "from dolfinx import DirichletBC, Function, FunctionSpace\n",
    "from dolfinx.fem import locate_dofs_topological\n",
    "from dolfinx.io import XDMFFile\n",
    "from dolfinx.plot import create_vtk_topology\n",
    "from multiphenicsx.fem import (apply_lifting, assemble_matrix, assemble_matrix_block, assemble_scalar,\n",
    "                               assemble_vector, assemble_vector_block, BlockVecSubVectorWrapper,\n",
    "                               create_vector, create_vector_block, create_matrix, create_matrix_block,\n",
    "                               DofMapRestriction, set_bc)\n",
    "import pyvista"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mesh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with XDMFFile(MPI.COMM_WORLD, \"data/circle.xdmf\", \"r\") as infile:\n",
    "    mesh = infile.read_mesh()\n",
    "    mesh.topology.create_connectivity_all()\n",
    "    subdomains = infile.read_meshtags(mesh, name=\"subdomains\")\n",
    "    boundaries = infile.read_meshtags(mesh, name=\"boundaries\")\n",
    "facets_Gamma = boundaries.indices[boundaries.values == 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define associated measures\n",
    "dx = Measure(\"dx\")(subdomain_data=subdomains)\n",
    "ds = Measure(\"ds\")(subdomain_data=boundaries)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dolfinx_to_pyvista_mesh(mesh):\n",
    "    num_cells = mesh.topology.index_map(mesh.topology.dim).size_local\n",
    "    cell_entities = np.arange(num_cells, dtype=np.int32)\n",
    "    pyvista_cells, cell_types = create_vtk_topology(mesh, mesh.topology.dim, cell_entities)\n",
    "    grid = pyvista.UnstructuredGrid(pyvista_cells, cell_types, mesh.geometry.x)\n",
    "    return grid"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pyvista_mesh_plot(mesh):\n",
    "    grid = dolfinx_to_pyvista_mesh(mesh)\n",
    "    plotter = pyvista.PlotterITK()\n",
    "    plotter.add_mesh(grid)\n",
    "    plotter.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pyvista_mesh_plot(mesh)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Weak imposition of Dirichlet BCs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a function space\n",
    "V = FunctionSpace(mesh, (\"Lagrange\", 2))\n",
    "M = V.clone()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define restrictions\n",
    "dofs_V = np.arange(0, V.dofmap.index_map.size_local + V.dofmap.index_map.num_ghosts)\n",
    "dofs_M_Gamma = locate_dofs_topological(M, boundaries.dim, facets_Gamma)\n",
    "restriction_V = DofMapRestriction(V.dofmap, dofs_V)\n",
    "restriction_M_Gamma = DofMapRestriction(M.dofmap, dofs_M_Gamma)\n",
    "restriction = [restriction_V, restriction_M_Gamma]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define trial and test functions, as well as solution\n",
    "(du, dl) = (TrialFunction(V), TrialFunction(M))\n",
    "(u, l) = (Function(V), Function(M))\n",
    "(v, m) = (TestFunction(V), TestFunction(M))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define problem block forms\n",
    "g = Function(V)\n",
    "g.interpolate(lambda x: np.sin(3 * x[0] + 1) * np.sin(3 * x[1] + 1))\n",
    "F = [inner((1 + u**2) * grad(u), grad(v)) * dx + u * v * inner(grad(u), grad(u)) * dx + l * v * ds - v * dx,\n",
    "     u * m * ds - g * m * ds]\n",
    "J = [[derivative(F[0], u, du), derivative(F[0], l, dl)],\n",
    "     [derivative(F[1], u, du), derivative(F[1], l, dl)]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class for interfacing with the SNES\n",
    "class NonlinearLagrangeMultplierBlockProblem(object):\n",
    "    def __init__(self, F, J, solutions, bcs, P=None):\n",
    "        self._F = F\n",
    "        self._J = J\n",
    "        self._obj_vec = create_vector_block(F, restriction)\n",
    "        self._solutions = solutions\n",
    "        self._bcs = bcs\n",
    "        self._P = P\n",
    "\n",
    "    def update_solutions(self, x):\n",
    "        x.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)\n",
    "        with BlockVecSubVectorWrapper(x, [V.dofmap, M.dofmap], restriction) as x_wrapper:\n",
    "            for x_wrapper_local, sub_solution in zip(x_wrapper, self._solutions):\n",
    "                with sub_solution.vector.localForm() as sub_solution_local:\n",
    "                    sub_solution_local[:] = x_wrapper_local\n",
    "\n",
    "    def obj(self, snes, x):\n",
    "        self.F(snes, x, self._obj_vec)\n",
    "        return self._obj_vec.norm()\n",
    "\n",
    "    def F(self, snes, x, F_vec):\n",
    "        self.update_solutions(x)\n",
    "        with F_vec.localForm() as F_vec_local:\n",
    "            F_vec_local.set(0.0)\n",
    "        assemble_vector_block(F_vec, self._F, self._J, self._bcs, x0=x,\n",
    "                              scale=-1.0, restriction=restriction, restriction_x0=restriction)\n",
    "\n",
    "    def J(self, snes, x, J_mat, P_mat):\n",
    "        J_mat.zeroEntries()\n",
    "        assemble_matrix_block(J_mat, self._J, self._bcs, diagonal=1.0,\n",
    "                              restriction=(restriction, restriction))\n",
    "        J_mat.assemble()\n",
    "        if self._P is not None:\n",
    "            P_mat.zeroEntries()\n",
    "            assemble_matrix_block(P_mat, self._P, self._bcs, diagonal=1.0,\n",
    "                                  restriction=(restriction, restriction))\n",
    "            P_mat.assemble()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create problem\n",
    "problem = NonlinearLagrangeMultplierBlockProblem(F, J, (u, l), [])\n",
    "F_vec = create_vector_block(F, restriction=restriction)\n",
    "J_mat = create_matrix_block(J, restriction=(restriction, restriction))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solve\n",
    "solution = create_vector_block(F, restriction=restriction)\n",
    "snes = PETSc.SNES().create(mesh.mpi_comm())\n",
    "snes.setTolerances(max_it=20)\n",
    "snes.getKSP().setType(\"preonly\")\n",
    "snes.getKSP().getPC().setType(\"lu\")\n",
    "snes.getKSP().getPC().setFactorSolverType(\"mumps\")\n",
    "snes.setObjective(problem.obj)\n",
    "snes.setFunction(problem.F, F_vec)\n",
    "snes.setJacobian(problem.J, J=J_mat, P=None)\n",
    "snes.setMonitor(lambda _, it, residual: print(it, residual))\n",
    "snes.solve(None, solution)\n",
    "problem.update_solutions(solution)  # TODO can this be safely removed?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pyvista_scalar_field_plot(mesh, scalar_field, name):\n",
    "    grid = dolfinx_to_pyvista_mesh(mesh)\n",
    "    grid.point_arrays[name] = scalar_field.compute_point_values()\n",
    "    grid.set_active_scalars(name)\n",
    "    plotter = pyvista.PlotterITK()\n",
    "    plotter.add_mesh(grid)\n",
    "    plotter.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pyvista_scalar_field_plot(mesh, u, \"u\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pyvista_scalar_field_plot(mesh, l, \"l\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Strong imposition of Dirichlet BCs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class for interfacing with the SNES\n",
    "class NonlinearLagrangeMultplierProblem(object):\n",
    "    def __init__(self, F, J, solution, bcs, P=None):\n",
    "        self._F = F\n",
    "        self._J = J\n",
    "        self._obj_vec = create_vector(F)\n",
    "        self._solution = solution\n",
    "        self._bcs = bcs\n",
    "        self._P = P\n",
    "\n",
    "    def update_solution(self, x):\n",
    "        x.ghostUpdate(addv=PETSc.InsertMode.INSERT, mode=PETSc.ScatterMode.FORWARD)\n",
    "        with x.localForm() as _x, self._solution.vector.localForm() as _solution:\n",
    "            _solution[:] = _x\n",
    "\n",
    "    def obj(self, snes, x):\n",
    "        self.F(snes, x, self._obj_vec)\n",
    "        return self._obj_vec.norm()\n",
    "\n",
    "    def F(self, snes, x, F_vec):\n",
    "        self.update_solution(x)\n",
    "        with F_vec.localForm() as F_vec_local:\n",
    "            F_vec_local.set(0.0)\n",
    "        assemble_vector(F_vec, self._F)\n",
    "        apply_lifting(F_vec, [self._J], [self._bcs], x0=[x], scale=-1.0)\n",
    "        F_vec.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)\n",
    "        set_bc(F_vec, self._bcs, x, -1.0)\n",
    "\n",
    "    def J(self, snes, x, J_mat, P_mat):\n",
    "        J_mat.zeroEntries()\n",
    "        assemble_matrix(J_mat, self._J, self._bcs, diagonal=1.0)\n",
    "        J_mat.assemble()\n",
    "        if self._P is not None:\n",
    "            P_mat.zeroEntries()\n",
    "            assemble_matrix(P_mat, self._P, self._bcs, diagonal=1.0)\n",
    "            P_mat.assemble()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define problem block forms\n",
    "u_ex = Function(V)\n",
    "F_ex = replace(F[0], {u: u_ex, l: 0})\n",
    "J_ex = derivative(F_ex, u_ex, du)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define Dirichlet BC object on Gamma\n",
    "dofs_V_Gamma = locate_dofs_topological(V, boundaries.dim, facets_Gamma)\n",
    "bc_ex = [DirichletBC(g, dofs_V_Gamma)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create problem\n",
    "problem_ex = NonlinearLagrangeMultplierProblem(F_ex, J_ex, u_ex, bc_ex)\n",
    "F_ex_vec = create_vector(F_ex)\n",
    "J_ex_mat = create_matrix(J_ex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solve\n",
    "solution_ex = create_vector(F_ex)\n",
    "snes = PETSc.SNES().create(mesh.mpi_comm())\n",
    "snes.setTolerances(max_it=20)\n",
    "snes.getKSP().setType(\"preonly\")\n",
    "snes.getKSP().getPC().setType(\"lu\")\n",
    "snes.getKSP().getPC().setFactorSolverType(\"mumps\")\n",
    "snes.setObjective(problem_ex.obj)\n",
    "snes.setFunction(problem_ex.F, F_ex_vec)\n",
    "snes.setJacobian(problem_ex.J, J=J_ex_mat, P=None)\n",
    "snes.setMonitor(lambda _, it, residual: print(it, residual))\n",
    "snes.solve(None, solution_ex)\n",
    "problem_ex.update_solution(solution_ex)  # TODO can this be safely removed?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pyvista_scalar_field_plot(mesh, u_ex, \"u\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Comparison and error computation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "u_ex_norm = np.sqrt(mesh.mpi_comm().allreduce(assemble_scalar(inner(grad(u_ex), grad(u_ex)) * dx), op=MPI.SUM))\n",
    "err_norm = np.sqrt(mesh.mpi_comm().allreduce(assemble_scalar(inner(grad(u_ex - u), grad(u_ex - u)) * dx), op=MPI.SUM))\n",
    "print(\"Relative error is equal to\", err_norm / u_ex_norm)\n",
    "assert np.isclose(err_norm / u_ex_norm, 0., atol=1.e-9)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython"
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
